import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import sys
import os

from advtr.data import train_loader, test_loader
from advtr.attacks import fgsm, pgd, pgd2, pgd_linf, rfgsm, rpgd, rpgd2, rpgd_linf
from advtr.train import epoch_adversarial
from advtr.model import model_gen

# Evaluate every trained model (standard + adversarially trained variants) for
# one (net, eps, alpha, seed) configuration against a fixed suite of attacks.
# Per-model error lists are cached in EVAL_DIR; models whose results already
# exist are skipped (and not even loaded).
#
# Usage: python <script> NET EPS ALPHA SEED

MODEL_DIR = "./results/models"
EVAL_DIR = "./results/evals"

# (header printed before evaluation, file-name prefix shared by the checkpoint
# in MODEL_DIR and the cached error list in EVAL_DIR), in evaluation order.
MODELS = [
    ("Standard Model", "model_standard"),
    ("PGD Model", "model_atk=pgd"),
    ("R-PGD Model", "model_atk=rpgd"),
    ("PGD2 Model", "model_atk=pgd2"),
    ("R-PGD2 Model", "model_atk=rpgd2"),
]

# (label, attack function, whether the attack also takes step size `alpha`,
# whether to print a blank line after the result). This order fixes the order
# of the error list saved to EVAL_DIR:
# [fgsm, rfgsm, pgd, rpgd, pgd_linf, rpgd_linf, pgd2, rpgd2].
ATTACKS = [
    ("FGSM", fgsm, False, False),
    ("RFGSM", rfgsm, False, False),
    ("PGD", pgd, True, False),
    ("RPGD", rpgd, True, False),
    ("PGD_Linf", pgd_linf, True, False),
    ("RPGD_Linf", rpgd_linf, True, False),
    ("PGD2", pgd2, True, True),
    ("RPGD2", rpgd2, True, True),
]


def _parse_args(argv):
    """Return (net, eps, alpha, seed) parsed from positional CLI arguments.

    Exits with a usage message instead of an IndexError when arguments are missing.
    """
    if len(argv) < 5:
        sys.exit("usage: %s NET EPS ALPHA SEED" % argv[0])
    return argv[1], float(argv[2]), float(argv[3]), int(argv[4])


def _file_name(directory, prefix, net, eps, alpha, seed):
    """Build the checkpoint / evaluation path for one model configuration."""
    return "%s/%s_net=%s_eps=%s_alpha=%s_seed=%s.pt" % (
        directory, prefix, net, eps, alpha, seed)


def _load_model(net, path, device):
    """Instantiate architecture `net` on `device` and load the weights at `path`."""
    model = model_gen(net).to(device)
    # map_location lets checkpoints saved on a GPU load when `device` has
    # fallen back to the CPU (no CUDA available).
    model.load_state_dict(torch.load(path, map_location=device))
    # NOTE(review): the model is never switched to eval() mode here (nor was it
    # originally); confirm epoch_adversarial handles train/eval mode if the
    # architecture uses BatchNorm/Dropout.
    return model


def _evaluate(model, eps, alpha, seed):
    """Run every attack in ATTACKS against `model`, printing and returning the errors.

    Returns:
        List of the first element returned by `epoch_adversarial` for each
        attack, in ATTACKS order.
    """
    # Reset the RNG so every model is evaluated from the same random state.
    torch.manual_seed(seed)
    errs = []
    for label, attack, uses_alpha, blank_after in ATTACKS:
        kwargs = {"epsilon": eps}
        if uses_alpha:
            kwargs["alpha"] = alpha
        err = epoch_adversarial(test_loader, model, attack, **kwargs)[0]
        if blank_after:
            print("%s Attack: " % label, err, "\n")
        else:
            print("%s Attack: " % label, err)
        errs.append(err)
    return errs


def main():
    """Evaluate each model in MODELS whose results are not yet cached."""
    net, eps, alpha, seed = _parse_args(sys.argv)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    os.makedirs(EVAL_DIR, exist_ok=True)

    for header, prefix in MODELS:
        print(header)
        eval_path = _file_name(EVAL_DIR, prefix, net, eps, alpha, seed)
        if os.path.exists(eval_path):
            continue  # already evaluated; skip loading the model entirely
        model_path = _file_name(MODEL_DIR, prefix, net, eps, alpha, seed)
        model = _load_model(net, model_path, device)
        errs = _evaluate(model, eps, alpha, seed)
        torch.save(errs, eval_path)
        del model  # release device memory before the next model is loaded


if __name__ == "__main__":
    main()